from pathlib import Path

import yaml

from modules import utils
from modules.text_generation import get_encoded_length


def load_prompt(fname):
    """Return the contents of ``prompts/<fname>.txt`` minus one trailing newline.

    Args:
        fname: Prompt name without the ``.txt`` extension. The values
            ``'None'`` and ``''`` short-circuit to an empty result.

    Returns:
        The file text with at most one trailing newline removed, or ``''``
        when ``fname`` is ``'None'``/``''`` or the file does not exist.
    """
    if fname in ['None', '']:
        return ''

    file_path = Path(f'prompts/{fname}.txt')
    if not file_path.exists():
        return ''

    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # endswith() is safe for an empty file, whereas indexing text[-1]
    # would raise IndexError on an empty string.
    if text.endswith('\n'):
        text = text[:-1]

    return text


def load_instruction_prompt_simple(fname):
    """Build a prompt prefix from ``instruction-templates/<fname>.yaml``.

    The result is the template's optional ``context`` followed by the part
    of ``turn_template`` that precedes ``<|bot-message|>``, after that part
    has been passed through ``utils.replace_all`` with the template's
    ``user``/``bot`` values and the literal ``'Input'`` as the user message.
    Trailing spaces are stripped from the result.

    Args:
        fname: Template name without the ``.yaml`` extension.

    Returns:
        The assembled prompt string, or ``''`` if the template file does
        not exist.
    """
    template_path = Path(f'instruction-templates/{fname}.yaml')
    if not template_path.exists():
        return ''

    with open(template_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # Start from the optional context block, if the template defines one.
    prompt = ''
    if 'context' in data:
        prompt += data['context']

    replacements = {
        '<|user|>': data['user'],
        '<|bot|>': data['bot'],
        '<|user-message|>': 'Input',
    }

    # Keep only the portion of the turn that comes before the bot's reply.
    turn_head, _, _ = data['turn_template'].partition('<|bot-message|>')
    prompt += utils.replace_all(turn_head, replacements)
    return prompt.rstrip(' ')


def count_tokens(text):
    """Return the encoded length of ``text`` as a string.

    Args:
        text: The text handed to ``get_encoded_length``.

    Returns:
        ``str(get_encoded_length(text))``, or ``'0'`` if computing it
        raises an ordinary exception.
    """
    try:
        return str(get_encoded_length(text))
    except Exception:
        # Best-effort fallback. Catch Exception rather than using a bare
        # except so KeyboardInterrupt and SystemExit still propagate.
        return '0'
